#include <stdio.h>
#include <stdlib.h>
#include <stdbool.h>
#include <string.h>
#include <time.h>
#include <math.h>
#include <openssl/evp.h>

#include "rng.h"
#include "api.h"
#include "gmp.h"
#include "kaz_api.h"

/*
 * Sample an integer `out` in the closed interval [lb, ub].
 *
 * Randomness comes from randombytes(). It is presumably declared in rng.h
 * with the NIST PQC interface
 * `int randombytes(unsigned char *x, unsigned long long xlen)` (TODO confirm).
 *
 * Do not go back to gmp_randinit_default() with a fixed or time-based seed.
 * That generator is not cryptographic, and a fixed seed makes every call with
 * the same bounds return the same value (so every key pair is identical).
 *
 * The sample is built from bitlen(range) + 64 random bits and reduced modulo
 * the range size. The resulting bias from uniform is below 2^-64.
 *
 * Precondition: lb <= ub. An empty interval aborts the process instead of
 * silently producing a predictable value.
 */
void KAZ_KA_RANDOM(mpz_t lb, mpz_t ub, mpz_t out)
{
	mpz_t range, acc, part;
	unsigned char chunk[64];
	volatile unsigned char *wipe = chunk; /* volatile so the final wipe is not elided */
	size_t remaining;

	mpz_inits(range, acc, part, NULL); /* acc starts at 0 */

	/* range = ub - lb + 1 = number of admissible values */
	mpz_sub(range, ub, lb);
	mpz_add_ui(range, range, 1);

	if (mpz_sgn(range) <= 0) {
		fprintf(stderr, "KAZ-KA-RANDOM: Empty interval (ub < lb).\n");
		abort();
	}

	/* Accumulate random bytes big-endian into acc, one bounded chunk at a time. */
	remaining = (mpz_sizeinbase(range, 2) + 64 + 7) / 8;
	while (remaining > 0) {
		size_t take = remaining < sizeof chunk ? remaining : sizeof chunk;

		randombytes(chunk, take);
		mpz_import(part, take, 1, 1, 0, 0, chunk);
		mpz_mul_2exp(acc, acc, 8 * take);
		mpz_add(acc, acc, part);
		remaining -= take;
	}

	/* out = lb + (acc mod range), so lb <= out <= ub */
	mpz_mod(acc, acc, range);
	mpz_add(out, lb, acc);

	/* Clear the last chunk of secret randomness from the stack. */
	for (size_t i = 0; i < sizeof chunk; i++)
		wipe[i] = 0;
	mpz_clears(range, acc, part, NULL);
}

/*
 * Write x big-endian, right-aligned and left-zero-padded, into dst[0..len).
 * Returns 0 on success.
 * Returns -1 (dst untouched) if x needs more than len bytes.
 */
static int kaz_ka_keygen_store(unsigned char *dst, size_t len, const mpz_t x)
{
	size_t need = (mpz_sizeinbase(x, 2) + 7) / 8;

	if (need > len)
		return -1;
	memset(dst, 0, len);
	/* For x == 0, mpz_export writes nothing, so the slot stays all-zero. */
	mpz_export(dst + (len - need), NULL, 1, 1, 0, 0, x);
	return 0;
}

/*
 * Generate a KAZ-KA key pair.
 *
 * kaz_ka_types selects the public-key formula:
 *   NULL       : e1 = g1^a1 * g2^a2     mod N,  e2 = g1^a2 * g2^(2*a1) mod N
 *   "PRIMARY"  : e1 = g1^a1 * g2^(2*a2) mod N,  e2 = g1^a2 * g2^a1     mod N
 * a1 is drawn from [2^(LOg1N-2), Og1N] and a2 from [2^(LOg2N-2), Og2N].
 *
 * Output layout (each component big-endian, left-zero-padded):
 *   kaz_ka_public_key  = e1 || e2, each KAZ_KA_PUBLICKEY_BYTES  wide
 *   kaz_ka_private_key = a1 || a2, each KAZ_KA_PRIVATEKEY_BYTES wide
 *
 * Returns:
 *    0  on success;
 *   -1  if kaz_ka_types is neither NULL nor "PRIMARY";
 *   -2  if a component does not fit its slot.
 * On error, both output buffers are zeroed.
 */
int KAZ_KA_KEYGEN(unsigned char *kaz_ka_public_key, unsigned char *kaz_ka_private_key, const unsigned char *kaz_ka_types)
{
	int ret = 0;
	bool primary;
	mpz_t N, g1, g2, Og1N, Og2N, a1, a2, e1, e2;
	mpz_t tmp, lowerbound;

	/* Reject unknown types up front instead of emitting an all-zero key. */
	if (kaz_ka_types == NULL) {
		primary = false;
	} else if (strcmp((const char *)kaz_ka_types, "PRIMARY") == 0) {
		primary = true;
	} else {
		fprintf(stderr, "KAZ-KA-KEYGEN: Unknown key type.\n");
		memset(kaz_ka_public_key, 0, KAZ_KA_PUBLICKEY_BYTES * 2);
		memset(kaz_ka_private_key, 0, KAZ_KA_PRIVATEKEY_BYTES * 2);
		return -1;
	}

	mpz_inits(N, g1, g2, Og1N, Og2N, a1, a2, e1, e2, NULL);
	mpz_inits(tmp, lowerbound, NULL);

	/* System parameters and precomputed group orders */
	mpz_set_str(N, KAZ_KA_SP_N, 10);
	mpz_set_str(g1, KAZ_KA_SP_g1, 10);
	mpz_set_str(g2, KAZ_KA_SP_g2, 10);
	mpz_set_str(Og1N, KAZ_KA_SP_Og1N, 10);
	mpz_set_str(Og2N, KAZ_KA_SP_Og2N, 10);

	/* Private exponents; a1 is drawn before a2 for both key types. */
	mpz_ui_pow_ui(lowerbound, 2, KAZ_KA_SP_LOg1N - 2);
	KAZ_KA_RANDOM(lowerbound, Og1N, a1);

	mpz_ui_pow_ui(lowerbound, 2, KAZ_KA_SP_LOg2N - 2);
	KAZ_KA_RANDOM(lowerbound, Og2N, a2);

	if (primary) {
		/* e1 = g1^a1 * g2^(2*a2) mod N */
		mpz_powm(e1, g1, a1, N);
		mpz_mul_ui(tmp, a2, 2);
		mpz_powm(tmp, g2, tmp, N);
		mpz_mul(e1, e1, tmp);
		mpz_mod(e1, e1, N);

		/* e2 = g1^a2 * g2^a1 mod N */
		mpz_powm(e2, g1, a2, N);
		mpz_powm(tmp, g2, a1, N);
		mpz_mul(e2, e2, tmp);
		mpz_mod(e2, e2, N);
	} else {
		/* e1 = g1^a1 * g2^a2 mod N */
		mpz_powm(e1, g1, a1, N);
		mpz_powm(tmp, g2, a2, N);
		mpz_mul(e1, e1, tmp);
		mpz_mod(e1, e1, N);

		/* e2 = g1^a2 * g2^(2*a1) mod N */
		mpz_powm(e2, g1, a2, N);
		mpz_mul_ui(tmp, a1, 2);
		mpz_powm(tmp, g2, tmp, N);
		mpz_mul(e2, e2, tmp);
		mpz_mod(e2, e2, N);
	}

	/*
	 * Encode directly into the caller's buffers. The size check replaces the
	 * old unchecked right-aligned copy loops, which could write out of bounds.
	 */
	if (kaz_ka_keygen_store(kaz_ka_public_key, KAZ_KA_PUBLICKEY_BYTES, e1) != 0 ||
	    kaz_ka_keygen_store(kaz_ka_public_key + KAZ_KA_PUBLICKEY_BYTES, KAZ_KA_PUBLICKEY_BYTES, e2) != 0 ||
	    kaz_ka_keygen_store(kaz_ka_private_key, KAZ_KA_PRIVATEKEY_BYTES, a1) != 0 ||
	    kaz_ka_keygen_store(kaz_ka_private_key + KAZ_KA_PRIVATEKEY_BYTES, KAZ_KA_PRIVATEKEY_BYTES, a2) != 0) {
		fprintf(stderr, "KAZ-KA-KEYGEN: Key component exceeds its encoded size.\n");
		memset(kaz_ka_public_key, 0, KAZ_KA_PUBLICKEY_BYTES * 2);
		memset(kaz_ka_private_key, 0, KAZ_KA_PRIVATEKEY_BYTES * 2);
		ret = -2;
	}

	mpz_clears(N, g1, g2, Og1N, Og2N, a1, a2, e1, e2, NULL);
	mpz_clears(tmp, lowerbound, NULL);

	return ret;
}

/*
 * Derive the shared key from a peer public key and our private key.
 *
 * Input layout (each half big-endian, fixed width):
 *   pk = e1 || e2, each KAZ_KA_PUBLICKEY_BYTES  wide
 *   sk = a1 || a2, each KAZ_KA_PRIVATEKEY_BYTES wide
 *
 * Computes SHARED = e1^a1 * e2^a2 mod N. The result is written big-endian,
 * left-zero-padded, into sharedkey[0..KAZ_KA_SHAREDKEY_BYTES).
 *
 * Returns 0 on success.
 * Returns -2 if SHARED does not fit in KAZ_KA_SHAREDKEY_BYTES; sharedkey is
 * then zeroed.
 *
 * pk/sk are imported in place. This avoids heap copies of secret material
 * and the old error path that free()d an uninitialized pointer.
 */
int KAZ_KEY_AGREEMENT(unsigned char *sharedkey, const unsigned char *pk, const unsigned char *sk)
{
	int ret = 0;
	size_t need;
	mpz_t N, e1, e2, a1, a2, SHARED, tmp;

	mpz_inits(N, e1, e2, a1, a2, SHARED, tmp, NULL);

	/* System parameter */
	mpz_set_str(N, KAZ_KA_SP_N, 10);

	/* Split pk = {e1, e2} and sk = {a1, a2} */
	mpz_import(e1, KAZ_KA_PUBLICKEY_BYTES, 1, 1, 0, 0, pk);
	mpz_import(e2, KAZ_KA_PUBLICKEY_BYTES, 1, 1, 0, 0, pk + KAZ_KA_PUBLICKEY_BYTES);
	mpz_import(a1, KAZ_KA_PRIVATEKEY_BYTES, 1, 1, 0, 0, sk);
	mpz_import(a2, KAZ_KA_PRIVATEKEY_BYTES, 1, 1, 0, 0, sk + KAZ_KA_PRIVATEKEY_BYTES);

	/* SHARED = e1^a1 * e2^a2 mod N */
	mpz_powm(SHARED, e1, a1, N);
	mpz_powm(tmp, e2, a2, N);
	mpz_mul(SHARED, SHARED, tmp);
	mpz_mod(SHARED, SHARED, N);

	memset(sharedkey, 0, KAZ_KA_SHAREDKEY_BYTES);

	/* Exact byte length; guards the right-aligned write against underflow. */
	need = (mpz_sizeinbase(SHARED, 2) + 7) / 8;
	if (need > (size_t)KAZ_KA_SHAREDKEY_BYTES) {
		fprintf(stderr, "KAZ-KA-KEYAGREEMENT: Shared key exceeds its encoded size.\n");
		ret = -2;
	} else {
		/* For SHARED == 0 nothing is written and the key stays all-zero. */
		mpz_export(sharedkey + (KAZ_KA_SHAREDKEY_BYTES - need), NULL, 1, 1, 0, 0, SHARED);
	}

	mpz_clears(N, e1, e2, a1, a2, SHARED, tmp, NULL);

	return ret;
}